package org.zjvis.datascience.service;

import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.Statement;
import java.sql.Types;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.regex.Pattern;

import com.alibaba.fastjson.JSON;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Service;
import org.zjvis.datascience.common.enums.ETLEnum;
import org.zjvis.datascience.common.sql.SqlHelper;
import org.zjvis.datascience.common.util.SqlUtil;
import org.zjvis.datascience.common.util.ToolUtil;
import org.zjvis.datascience.common.util.db.JDBCUtil;
import org.zjvis.datascience.common.vo.TaskVO;
import org.zjvis.datascience.service.dataprovider.GPDataProvider;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.mayabot.nlp.fasttext.FastText;

/**
 * @description Service that recommends column pairs for automatic Join / Union operations.
 *     Every column of the left table is compared with every column of the right table; a pair
 *     is scored from the similarity of the header names and the similarity of sampled values.
 * @date 2021-11-19
 */
@Service
public class AutoJoinService {
    private static final Logger logger = LoggerFactory.getLogger(AutoJoinService.class);

    /** Maximum number of rows sampled from each table when building column metadata. */
    protected static final int SAMPLE_NUMBER = 5000;

    /** Weight of the header score in the total score; the value score gets {@code 1 - SCORE_WEIGHT}. */
    private static final float SCORE_WEIGHT = 0.3f;

    /** Number of histogram bins used as the value vector of numeric columns. */
    private static final int NUM_BIN = 10;

    @Autowired
    private FastTextService fastTextService;

    @Autowired
    private TaskInstanceService taskInstanceService;

    @Autowired
    private GPDataProvider gpDataProvider;

    @Lazy
    @Autowired
    private TaskService taskService;

    /**
     * Computes recommendations for a node configuration that is not bound to a task yet.
     *
     * @param conf node configuration; read for "input" and "algType", and mutated in place
     * @param parentIds parent identifiers handed to {@code TaskInstanceService#getTableName}
     * @return the same {@code conf} instance, with the recommendation array attached
     */
    public JSONObject autoJoinRecommend(JSONObject conf, List<String> parentIds) {
        return autoJoinRecommend(conf, parentIds, -1L);
    }

    /**
     * Computes recommendations between the first two inputs and stores them on {@code conf}.
     *
     * <p>Any failure while computing is logged and leaves the recommendation array empty; the
     * array is stored under "autoJoinRecommendation" for JOIN and "autoUnionRecommendation" for
     * UNION, and is not stored at all for any other "algType".
     *
     * @param conf node configuration; mutated in place
     * @param parentIds parent identifiers, only used when no task id is given
     * @param taskId task whose full input is loaded; {@code null} or a negative value means the
     *     table names are resolved from {@code conf} and {@code parentIds} instead
     * @return the same {@code conf} instance
     */
    public JSONObject autoJoinRecommend(JSONObject conf, List<String> parentIds, Long taskId) {
        JSONArray result = new JSONArray();
        try {
            JSONArray input = conf.getJSONArray("input");
            List<String> tableNames = new ArrayList<>();
            // A null task id used to fail on unboxing; it now takes the "no task" path.
            if (taskId == null || taskId < 0) {
                tableNames = taskInstanceService.getTableName(input, parentIds);
            } else {
                TaskVO taskVO = taskService.queryFullInputById(taskId);
                input = taskVO.getData().getJSONArray("input");
                for (int i = 0; i < input.size(); ++i) {
                    tableNames.add(input.getJSONObject(i).getString("tableName"));
                }
            }
            String leftTableName = tableNames.get(0);
            String rightTableName = tableNames.get(1);

            JSONObject leftSemantics = JSONObject
                .parseObject(input.getJSONObject(0).getString("semantic"));
            JSONObject rightSemantics = JSONObject
                .parseObject(input.getJSONObject(1).getString("semantic"));

            result = autoJoinRecommend(leftTableName, rightTableName, leftSemantics,
                rightSemantics, fastTextService.getAutojoinModel());

        } catch (Exception e) {
            // Best effort: keep the empty result, but log the stack trace so the cause is visible.
            logger.error("autoJoinRecommend failed, errMsg={}", e.getMessage(), e);
        }
        int algType = conf.getIntValue("algType");
        if (ETLEnum.JOIN.getVal() == algType) {
            conf.put("autoJoinRecommendation", result);
        } else if (ETLEnum.UNION.getVal() == algType) {
            conf.put("autoUnionRecommendation", result);
        }
        return conf;
    }

    /**
     * Returns the best recommendations over all columns of both tables.
     *
     * <p>Only the first 20 ranked pairs are considered. When more than 10 pairs were found,
     * pairs scoring 0.1 or less are dropped. Each kept pair gets an "id" equal to its rank
     * index, so ids may have gaps.
     *
     * @param leftTableName left table
     * @param rightTableName right table
     * @param leftSemantics column name to semantic mapping of the left table
     * @param rightSemantics column name to semantic mapping of the right table
     * @param model fastText model used to embed names and values
     * @return the kept recommendations, best first
     */
    public JSONArray autoJoinRecommend(String leftTableName, String rightTableName,
        JSONObject leftSemantics, JSONObject rightSemantics, FastText model) {
        JSONArray recommendations = autoJoinRecommend("*", "*", leftTableName, rightTableName,
            leftSemantics, rightSemantics, model, 0);
        JSONArray ret = new JSONArray();
        int limit = Math.min(20, recommendations.size());
        for (int i = 0; i < limit; i++) {
            JSONObject recommendation = recommendations.getJSONObject(i);
            if (recommendations.size() > 10 && recommendation.getDouble("score") <= 0.1) {
                continue;
            }
            recommendation.put("id", i);
            ret.add(recommendation);
            // SLF4J placeholders cannot express %f, so format eagerly only when it will be used.
            if (logger.isInfoEnabled()) {
                logger
                    .info(String.format("Top %d: total score: %f, header score: %f, value score: %f, " +
                            "header name '%s' from table '%s', header name '%s' from table '%s'", i + 1,
                        recommendation.getDouble("score"),
                        recommendation.getDouble("headerScore"), recommendation.getDouble("valueScore"),
                        recommendation.getString("leftHeaderName"),
                        recommendation.getString("leftTableName"),
                        recommendation.getString("rightHeaderName"),
                        recommendation.getString("rightTableName")));
            }
        }
        return ret;
    }

    /**
     * Scores every (left column, right column) pair and returns those above {@code minScore}.
     *
     * <p>A pair only gets a non-zero score when both columns have the same merged SQL type and
     * the same semantic. Identical header names (ignoring case) count as similarity 1 for both
     * header and values; otherwise both similarities come from
     * {@code FastTextService#computeVectorSimilarity}. VARCHAR pairs get a 0.1 bonus on the
     * value similarity, capped at 1.
     *
     * @param leftCols select list used to sample the left table (for example "*")
     * @param rightCols select list used to sample the right table
     * @param leftTableName left table
     * @param rightTableName right table
     * @param leftSemantics column name to semantic mapping of the left table
     * @param rightSemantics column name to semantic mapping of the right table
     * @param model fastText model used to embed names and values
     * @param minScore exclusive lower bound on the total score; a negative value keeps every pair
     * @return the scored pairs, sorted by "score" in descending order
     */
    public JSONArray autoJoinRecommend(String leftCols, String rightCols,
        String leftTableName,
        String rightTableName, JSONObject leftSemantics, JSONObject rightSemantics,
        FastText model, double minScore) {
        List<JSONObject> autoJoinScores = new ArrayList<>();

        JSONArray leftMeta = getMetaForAutoJoin(leftCols, leftTableName, leftSemantics, model);
        JSONArray rightMeta = getMetaForAutoJoin(rightCols, rightTableName, rightSemantics, model);

        for (int i = 0; i < leftMeta.size(); i++) {
            JSONObject left = leftMeta.getJSONObject(i);
            int leftType = left.getIntValue("type");
            String leftHeaderName = left.getString("name");
            String leftSemantic = left.getString("semantic");
            double[] leftNameVector = (double[]) left.get("nameVector");
            double[] leftValueVector = (double[]) left.get("valueVector");
            for (int j = 0; j < rightMeta.size(); j++) {
                JSONObject right = rightMeta.getJSONObject(j);
                int rightType = right.getIntValue("type");
                String rightHeaderName = right.getString("name");
                String rightSemantic = right.getString("semantic");
                double[] rightNameVector = (double[]) right.get("nameVector");
                double[] rightValueVector = (double[]) right.get("valueVector");

                // Scores are local to the pair so that an incompatible pair reports zeros
                // instead of the values left over from the previously scored pair.
                double headerScore = 0d;
                double valueScore = 0d;
                double similarityScore = 0d;
                if (leftType == rightType && leftSemantic.equals(rightSemantic)) {
                    double headerSimilarity;
                    double valueSimilarity;
                    if (leftHeaderName.equalsIgnoreCase(rightHeaderName)) {
                        headerSimilarity = 1;
                        valueSimilarity = 1;
                    } else {
                        headerSimilarity = fastTextService
                            .computeVectorSimilarity(leftNameVector, rightNameVector);
                        valueSimilarity = fastTextService
                            .computeVectorSimilarity(leftValueVector, rightValueVector);
                    }
                    if (leftType == Types.VARCHAR) {
                        valueSimilarity = Math.min(1, valueSimilarity + 0.1);
                    }
                    headerScore = SCORE_WEIGHT * headerSimilarity;
                    valueScore = (1f - SCORE_WEIGHT) * valueSimilarity;
                    similarityScore = headerScore + valueScore;
                }
                if (similarityScore > minScore) {
                    JSONObject autoJoinScore = new JSONObject();
                    autoJoinScore.put("score", similarityScore);
                    autoJoinScore.put("headerScore", headerScore);
                    autoJoinScore.put("valueScore", valueScore);
                    autoJoinScore.put("leftHeaderName", leftHeaderName);
                    autoJoinScore.put("leftTableName", leftTableName);
                    autoJoinScore.put("rightHeaderName", rightHeaderName);
                    autoJoinScore.put("rightTableName", rightTableName);
                    autoJoinScore.put("operator", "=");
                    autoJoinScores.add(autoJoinScore);
                }
            }
        }

        return sortJsonArray(autoJoinScores, "score");
    }

    /**
     * Ranks the columns of the opposite table against one chosen column.
     *
     * @param param request with "taskId", "side" ("left" selects the left table; any other
     *     value, including a missing one, selects the right table) and "col" (the chosen column)
     * @return header names of the opposite table, best match first
     */
    public JSONArray joinRecommend(JSONObject param) {
        JSONArray ret = new JSONArray();
        Long taskId = param.getLong("taskId");
        // Null-safe: a missing "side" is handled like every other non-"left" value.
        boolean leftSide = "left".equals(param.getString("side"));
        String col = param.getString("col");

        List<String> tableNames = new ArrayList<>();
        TaskVO taskVO = taskService.queryFullInputById(taskId);
        JSONArray input = taskVO.getData().getJSONArray("input");
        for (int i = 0; i < input.size(); ++i) {
            tableNames.add(input.getJSONObject(i).getString("tableName"));
        }
        String leftTableName = tableNames.get(0);
        String rightTableName = tableNames.get(1);

        JSONObject leftSemantics = JSONObject
            .parseObject(input.getJSONObject(0).getString("semantic"));
        JSONObject rightSemantics = JSONObject
            .parseObject(input.getJSONObject(1).getString("semantic"));
        // NOTE(review): "col" comes from the request and ends up inside the sampling SQL built
        // by getMetaForAutoJoin; confirm it is validated as a column name upstream.
        String leftCols = leftSide ? col : "*";
        String rightCols = leftSide ? "*" : col;

        // minScore -1 keeps every pair, so the whole opposite table is ranked.
        JSONArray recommendations = autoJoinRecommend(leftCols, rightCols, leftTableName,
            rightTableName, leftSemantics, rightSemantics, fastTextService.getAutojoinModel(), -1);
        String oppositeKey = leftSide ? "rightHeaderName" : "leftHeaderName";
        for (Object obj : recommendations) {
            ret.add(((JSONObject) obj).getString(oppositeKey));
        }
        return ret;
    }

    /**
     * Builds a normalized equal-width histogram of numeric strings.
     *
     * <p>Null, NaN and infinite entries are ignored. When no usable value remains the result is
     * all zeros; otherwise the bins sum to 1. When every value is identical, all mass goes to
     * the first bin.
     *
     * @param data numeric values in textual form
     * @param numBin number of bins
     * @return array of length {@code numBin} holding the fraction of values per bin
     * @throws NumberFormatException if a non-null entry is not a parsable number
     */
    private double[] calHist(List<String> data, int numBin) {
        double[] bins = new double[numBin];
        if (data == null || data.isEmpty()) {
            return bins;
        }

        // Parse once, keeping only finite values, and track the range on the way.
        double[] parsed = new double[data.size()];
        int count = 0;
        double min = Double.POSITIVE_INFINITY;
        double max = Double.NEGATIVE_INFINITY;
        for (String raw : data) {
            if (raw == null) {
                continue;
            }
            double d = Double.parseDouble(raw);
            if (Double.isNaN(d) || Double.isInfinite(d)) {
                continue;
            }
            parsed[count++] = d;
            if (d < min) {
                min = d;
            }
            if (d > max) {
                max = d;
            }
        }
        if (count == 0) {
            return bins;
        }

        double binWidth = (max - min) / numBin;
        for (int i = 0; i < count; i++) {
            int binIndex = binWidth > 0 ? (int) ((parsed[i] - min) / binWidth) : 0;
            // The maximum value computes to index numBin; fold it into the last bin.
            if (binIndex >= numBin) {
                binIndex = numBin - 1;
            }
            bins[binIndex] += 1;
        }

        for (int i = 0; i < numBin; i++) {
            bins[i] /= count;
        }
        return bins;
    }

    /**
     * Sorts the objects by a numeric field, largest first, and wraps them in a JSONArray.
     *
     * @param json objects to sort; the list itself is reordered
     * @param key name of the numeric field to sort on
     * @return a new array holding the objects in descending order of {@code key}
     */
    private JSONArray sortJsonArray(List<JSONObject> json, String key) {
        json.sort((a, b) -> Double.compare(b.getDoubleValue(key), a.getDoubleValue(key)));
        JSONArray sortedJsonArray = new JSONArray();
        sortedJsonArray.addAll(json);
        return sortedJsonArray;
    }

    /**
     * Samples a table and builds one metadata object per column.
     *
     * <p>Each object carries "name", "semantic" (the string "null" when none is known), "type"
     * (merged SQL type, with INTEGER and DECIMAL folded into NUMERIC), "nameVector" and
     * "valueVector". The "_record_id_" column is skipped. On any error the failure is logged
     * and whatever was built so far is returned.
     *
     * @param cols select list for the sampling query
     * @param tableName table to sample
     * @param semantics column name to semantic mapping; may be null
     * @param model fastText model used to embed names and values
     * @return metadata objects, possibly empty
     */
    private JSONArray getMetaForAutoJoin(String cols, String tableName, JSONObject semantics,
        FastText model) {
        JSONArray ret = new JSONArray();

        // NOTE(review): cols and tableName are interpolated into the SQL text. Identifiers
        // cannot be bound as statement parameters, so callers must only pass trusted or
        // validated names here.
        String selectSql = String
            .format("select %s from %s limit %s", cols, tableName, SAMPLE_NUMBER);
        Connection conn = null;
        try {
            conn = gpDataProvider.getConn(1L);

            List<String> headerNames = new ArrayList<>();
            JSONArray headers = new JSONArray();
            JSONArray values = new JSONArray();
            // Statement and ResultSet are closed here; the connection is released in finally.
            try (Statement st = conn.createStatement();
                ResultSet rs = st.executeQuery(selectSql)) {
                ResultSetMetaData meta = rs.getMetaData();
                int columnCount = meta.getColumnCount();
                for (int i = 1; i <= columnCount; i++) {
                    String name = meta.getColumnName(i);
                    // Every column is read into the sample rows, including the record id.
                    headerNames.add(name);
                    if (name.equals("_record_id_")) {
                        continue;
                    }
                    JSONObject header = new JSONObject();
                    header.put("name", name);
                    header.put("type", meta.getColumnType(i));
                    header.put("desc", SqlUtil.changeType(meta.getColumnTypeName(i)));
                    headers.add(header);
                }

                while (rs.next()) {
                    JSONObject row = new JSONObject();
                    for (String header : headerNames) {
                        row.put(header, rs.getString(header));
                    }
                    values.add(row);
                }
            }
//            long countTable = values.size();

            for (int i = 0; i < headers.size(); i++) {
                JSONObject header = headers.getJSONObject(i);
                String name = header.getString("name");

                // Disabled heuristic, kept for reference:
//                String groupSql = String.format("select count(%s) as all, count(distinct %s) as distinct from %s", name, name, tableName);
//                rs = st.executeQuery(groupSql);
//                rs.next();
//                long countAll = rs.getLong("all");
//                long countGroup = rs.getLong("distinct");
//                double ratio = 1.0 * countGroup / countTable;
//                if (ratio < 0.02 && countAll > 1000000) {// on large data, do not recommend enum-like columns so the join does not take too long
//                    continue;
//                }

                // A missing semantics object means no column has a known semantic.
                String semantic = semantics == null ? null : semantics.getString(name);
                if (semantic == null) {
                    semantic = "null";
                }

                int type = SqlHelper.mergeSqlType(header.getInteger("type"));
                if (type == Types.INTEGER || type == Types.DECIMAL) {
                    type = Types.NUMERIC;
                }

                String newName = ToolUtil.standardString(name, ' ');
                double[] nameVector = fastTextService.getTextVector(model, newName);
                double[] valueVector = buildValueVector(semantic, type, values, name, model);

                JSONObject data = new JSONObject();
                data.put("name", name);
                data.put("semantic", semantic);
                data.put("type", type);
                data.put("nameVector", nameVector);
                data.put("valueVector", valueVector);
                ret.add(data);
            }
        } catch (Exception e) {
            logger.error("AutoJoinService error, errMsg={}", e.getMessage(), e);
        } finally {
            JDBCUtil.close(conn, null, null);
        }
        return ret;
    }

    /**
     * Builds the value vector of one column from the sampled rows.
     *
     * <p>Columns with a known semantic, or of a type other than VARCHAR / NUMERIC, get the
     * constant vector {@code {1.0}}. VARCHAR columns embed the text of up to 5 values
     * returned by {@code getDataFromJSON}; NUMERIC columns use a histogram of the sample.
     */
    private double[] buildValueVector(String semantic, int type, JSONArray values, String name,
        FastText model) {
        if (!semantic.equals("null")) {
            return new double[]{1.0};
        }
        if (type == Types.VARCHAR) {
            List<String> groupStrs = gpDataProvider.getDataFromJSON(values, name, 5);
            StringBuilder value = new StringBuilder();
            for (String s : groupStrs) {
                value.append(' ').append(s);
            }
            String newValue = ToolUtil.standardString(value.toString(), ' ');
            return fastTextService.getTextVector(model, newValue);
        }
        if (type == Types.NUMERIC) {
            List<String> calData = gpDataProvider
                .getDataFromJSON(values, name, SAMPLE_NUMBER);
            return calHist(calData, NUM_BIN);
        }
        return new double[]{1.0};
    }
}
